#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import asyncio
import json
import math

from datetime import datetime, timedelta
from multiprocessing.util import Finalize

from celery import current_app, schedules
from celery.beat import ScheduleEntry, Scheduler
from celery.signals import beat_init
from celery.utils.log import get_logger
from redis.asyncio.lock import Lock
from sqlalchemy import select
from sqlalchemy.exc import DatabaseError, InterfaceError

from backend.app.task.enums import PeriodType, TaskSchedulerType
from backend.app.task.model.scheduler import TaskScheduler
from backend.app.task.schema.scheduler import CreateTaskSchedulerParam
from backend.app.task.utils.tzcrontab import TzAwareCrontab, crontab_verify
from backend.common.exception import errors
from backend.core.conf import settings
from backend.database.db import async_db_session
from backend.database.redis import redis_client
from backend.utils._await import run_await
from backend.utils.serializers import select_as_dict
from backend.utils.timezone import timezone

# This scheduler must wake up more often than the usual 5 minutes, because it
# has to pick up external changes to the schedule.
DEFAULT_MAX_INTERVAL = 5  # seconds

# Lifetime (TTL) of the scheduler lock; prevents duplicate schedulers.
DEFAULT_MAX_LOCK_TIMEOUT = 300  # seconds

# How often the lock is refreshed; must be shorter than the lock lifetime.
DEFAULT_LOCK_INTERVAL = 60  # seconds

# Copied from:
# https://github.com/andymccurdy/redis-py/blob/master/redis/lock.py#L33
# Changes:
#     The second line from the bottom: the original Lua script extends the
#     lock to (remaining TTL + additional time), while this script sets the
#     TTL to the given value, i.e. to an expected expiration time.
# KEYS[1] - lock name
# ARGV[1] - token
# ARGV[2] - new TTL in milliseconds (passed to PEXPIRE)
# Returns 1 if the lock's TTL was reset, otherwise 0 (token missing or not
# matching, or the key's PTTL is negative).
# NOTE(review): newer redis-py versions pass an extra ARGV (replace_ttl) that
# this script ignores; `Lock.extend(..., replace_ttl=True)` may make this
# override unnecessary -- verify against the installed redis-py version.
LUA_EXTEND_TO_SCRIPT = """
    local token = redis.call('get', KEYS[1])
    if not token or token ~= ARGV[1] then
        return 0
    end
    local expiration = redis.call('pttl', KEYS[1])
    if not expiration then
        expiration = 0
    end
    if expiration < 0 then
        return 0
    end
    redis.call('pexpire', KEYS[1], ARGV[2])
    return 1
"""

logger = get_logger('fba.schedulers')


class ModelEntry(ScheduleEntry):
    """Celery beat schedule entry backed by a ``TaskScheduler`` database model."""

    # Strong references to pending ``_disable`` tasks: the event loop only keeps weak
    # references to tasks, so an unreferenced task may be garbage-collected mid-run.
    _background_tasks: set = set()

    def __init__(self, model: TaskScheduler, app=None):
        """
        Build a schedule entry from a task scheduler model.

        A model whose schedule cannot be built, or whose ``args``/``kwargs`` are not
        valid JSON, is disabled (immediately in memory, and persisted via ``_disable``).
        An empty ``model.last_run_time`` is initialised on the model.

        :param model: the task scheduler model
        :param app: Celery app; defaults to the current app
        """
        super().__init__(
            app=app or current_app._get_current_object(),
            name=model.name,
            task=model.task,
        )
        # Set before anything that may disable the task, so ``_disable`` can rely on it
        self.model = model

        try:
            if (
                model.type == TaskSchedulerType.INTERVAL
                and model.interval_every is not None
                and model.interval_period is not None
            ):
                self.schedule = schedules.schedule(timedelta(**{model.interval_period: model.interval_every}))
            elif model.type == TaskSchedulerType.CRONTAB and model.crontab is not None:
                # Field order matches Celery's ``crontab`` positional arguments and the
                # string built in ``to_model_schedule``:
                # minute hour day_of_week day_of_month month_of_year
                crontab_split = model.crontab.split(' ')
                self.schedule = TzAwareCrontab(
                    minute=crontab_split[0],
                    hour=crontab_split[1],
                    day_of_week=crontab_split[2],
                    day_of_month=crontab_split[3],
                    month_of_year=crontab_split[4],
                )
            else:
                raise errors.NotFoundError(msg=f'{self.name} 计划为空！')
        except Exception as e:
            logger.error(f'禁用计划为空的任务 {self.name}，详情：{e}')
            self._schedule_disable(model)

        try:
            self.args = json.loads(model.args) if model.args else None
            self.kwargs = json.loads(model.kwargs) if model.kwargs else None
        except ValueError as exc:
            logger.error(f'禁用参数错误的任务：{self.name}；error: {str(exc)}')
            self._schedule_disable(model)

        self.options = {}
        for option in ['queue', 'exchange', 'routing_key']:
            value = getattr(model, option)
            if value is None:
                continue
            self.options[option] = value

        # Prefer an ``expires_`` attribute if the model provides one; otherwise fall back
        # to the ``expire_seconds`` / ``expire_time`` fields written by ``_unpack_options``.
        expires = getattr(model, 'expires_', None)
        if expires is None:
            expires = getattr(model, 'expire_seconds', None) or getattr(model, 'expire_time', None)
        if expires:
            if isinstance(expires, int):
                self.options['expires'] = expires
            elif isinstance(expires, datetime):
                self.options['expires'] = timezone.from_datetime(expires)

        if not model.last_run_time:
            model.last_run_time = timezone.now()
            if model.start_time:
                # Backdate well before start_time (presumably so the first due check
                # after start_time fires immediately)
                model.last_run_time = timezone.from_datetime(model.start_time) - timedelta(days=365)

        self.last_run_at = timezone.from_datetime(model.last_run_time)
        self.options['periodic_task_name'] = model.name

    def _schedule_disable(self, model: TaskScheduler) -> None:
        """
        Disable the task immediately in memory and persist it via ``_disable``.

        Inside a running event loop the database write is scheduled as a background
        task; from synchronous code it is run to completion with ``run_await``.

        :param model: the task scheduler model to disable
        :return:
        """
        # Disable synchronously so ``is_due`` never evaluates an entry whose schedule
        # or arguments could not be built.
        self.enabled = model.enabled = False
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop (e.g. ``__next__`` during a synchronous beat tick), where
            # ``asyncio.create_task`` would raise ``RuntimeError``.
            try:
                run_await(self._disable)(model)
            except Exception as e:
                logger.error(f'Failed to persist disabled state of task {self.name}: {e}')
            return
        task = asyncio.create_task(self._disable(model))
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def _disable(self, model: TaskScheduler) -> None:
        """
        Mark the task as disabled and persist ``enabled = False`` to the database.

        :param model: the task scheduler model to disable
        :return:
        """
        model.no_changes = True
        self.model.enabled = self.enabled = model.enabled = False
        if model.id is None:
            # Not persisted yet: the in-memory ``enabled = False`` is what gets
            # inserted if the model is committed (e.g. by ``from_entry``).
            return
        async with async_db_session.begin() as db:
            # ``model`` may belong to another (possibly closed) session, so update this
            # session's locked copy of the row instead.
            stmt = select(TaskScheduler).where(TaskScheduler.id == model.id).with_for_update()
            query = await db.execute(stmt)
            task = query.scalars().first()
            if task:
                task.enabled = False

    def is_due(self) -> tuple[bool, int | float]:
        """
        Report whether the task is due, honouring ``enabled``, ``start_time`` and ``one_off``.

        :return: Celery ``schedstate`` of (is_due, seconds until the next check)
        """
        if not self.model.enabled:
            # Check again after 5 seconds in case the task gets re-enabled
            return schedules.schedstate(is_due=False, next=5)

        # Only run after ``start_time``
        if self.model.start_time is not None:
            now = timezone.now()
            start_time = timezone.from_datetime(self.model.start_time)
            if now < start_time:
                delay = math.ceil((start_time - now).total_seconds())
                return schedules.schedstate(is_due=False, next=delay)

        # One-off task that has already run: disable it and persist the change
        if self.model.one_off and self.model.enabled and self.model.total_run_count > 0:
            self.model.enabled = False
            self.model.total_run_count = 0
            self.model.no_changes = False
            save_fields = ('enabled',)
            run_await(self.save)(save_fields)
            return schedules.schedstate(is_due=False, next=1000000000)  # very long delay to avoid re-checking

        return self.schedule.is_due(self.last_run_at)

    def __next__(self):
        """Record a run on the model and return a fresh entry for it (same Celery app)."""
        self.model.last_run_time = timezone.now()
        self.model.total_run_count += 1
        self.model.no_changes = True
        return self.__class__(self.model, app=self.app)

    next = __next__

    async def save(self, fields: tuple = ()):
        """
        Persist the task's run state fields to the database.

        Always writes ``last_run_time``, ``total_run_count`` and ``no_changes``.

        :param fields: additional model fields to write
        :return:
        """
        async with async_db_session.begin() as db:
            stmt = select(TaskScheduler).where(TaskScheduler.id == self.model.id).with_for_update()
            query = await db.execute(stmt)
            task = query.scalars().first()
            if task:
                for field in ['last_run_time', 'total_run_count', 'no_changes']:
                    setattr(task, field, getattr(self.model, field))
                for field in fields:
                    setattr(task, field, getattr(self.model, field))
            else:
                logger.warning(f'任务 {self.model.name} 不存在，跳过更新')

    @classmethod
    async def from_entry(cls, name, app=None, **entry):
        """
        Create or update the database row for a locally defined beat entry.

        :param name: task scheduler name
        :param app: Celery app
        :param entry: beat entry fields (``task``, ``schedule``, ``args``, ``kwargs``, ``options``, ...)
        :return: the resulting entry
        """
        async with async_db_session.begin() as db:
            stmt = select(TaskScheduler).where(TaskScheduler.name == name)
            query = await db.execute(stmt)
            task = query.scalars().first()
            temp = await cls._unpack_fields(name, **entry)
            if not task:
                task = TaskScheduler(**temp)
                db.add(task)
            else:
                for key, value in temp.items():
                    setattr(task, key, value)
            res = cls(task, app=app)
            return res

    @staticmethod
    async def to_model_schedule(name: str, task: str, schedule: schedules.schedule | TzAwareCrontab):
        """
        Find the model matching a Celery schedule, or build a new (unsaved) one.

        :param name: task scheduler name
        :param task: Celery task name, used when a new model is built
        :param schedule: an interval schedule or a crontab schedule
        :return: the matching or newly built ``TaskScheduler``
        :raises errors.NotFoundError: for unsupported schedule types
        """
        schedule = schedules.maybe_schedule(schedule)

        if isinstance(schedule, schedules.schedule):
            spec = {
                'name': name,
                'type': TaskSchedulerType.INTERVAL.value,
                'interval_every': max(schedule.run_every.total_seconds(), 0),
                'interval_period': PeriodType.SECONDS.value,
            }
        elif isinstance(schedule, schedules.crontab):
            crontab = f'{schedule._orig_minute} {schedule._orig_hour} {schedule._orig_day_of_week} {schedule._orig_day_of_month} {schedule._orig_month_of_year}'  # noqa: E501
            crontab_verify(crontab)
            spec = {
                'name': name,
                'type': TaskSchedulerType.CRONTAB.value,
                'crontab': crontab,
            }
        else:
            raise errors.NotFoundError(msg=f'暂不支持的计划类型：{schedule}')

        async with async_db_session() as db:
            query = await db.execute(select(TaskScheduler).filter_by(**spec))
            obj = query.scalars().first()
        if not obj:
            obj = TaskScheduler(**CreateTaskSchedulerParam(task=task, **spec).model_dump())
        return obj

    @classmethod
    async def _unpack_fields(
        cls,
        name: str,
        task: str,
        schedule: schedules.schedule | TzAwareCrontab,
        args: tuple | None = None,
        kwargs: dict | None = None,
        options: dict | None = None,
        **entry,
    ) -> dict:
        """
        Convert beat entry fields into keyword arguments for ``TaskScheduler``.

        :return: model field values (without database-managed columns)
        """
        model_schedule = await cls.to_model_schedule(name, task, schedule)
        model_dict = select_as_dict(model_schedule)
        # Drop database-managed columns so the dict can create or update a row
        for k in ('id', 'created_time', 'updated_time'):
            model_dict.pop(k, None)
        model_dict.update(
            args=json.dumps(args) if args else None,
            kwargs=json.dumps(kwargs) if kwargs else None,
            **cls._unpack_options(**options or {}),
            **entry,
        )
        return model_dict

    @classmethod
    def _unpack_options(
        cls,
        queue: str | None = None,
        exchange: str | None = None,
        routing_key: str | None = None,
        start_time: datetime | None = None,
        expires: datetime | timedelta | int | None = None,
        expire_seconds: int | None = None,
        one_off: bool = False,
    ) -> dict:
        """
        Map beat entry ``options`` onto model fields.

        ``expires`` may be an absolute datetime (stored as ``expire_time``), a
        ``timedelta`` (converted to an absolute ``expire_time`` from now), or an int
        number of seconds (stored as ``expire_seconds``, never as ``expire_time``).

        :return: model field values
        """
        expire_time = expires
        if isinstance(expires, int):
            # An integer is a relative lifetime in seconds, not a timestamp
            expire_time = None
            if expires:
                expire_seconds = expires
        elif isinstance(expires, timedelta):
            expire_time = timezone.now() + expires if expires else None
        return {
            'queue': queue,
            'exchange': exchange,
            'routing_key': routing_key,
            'start_time': start_time,
            'expire_time': expire_time,
            'expire_seconds': expire_seconds,
            'one_off': one_off,
        }


class DatabaseScheduler(Scheduler):
    """Celery beat scheduler that loads its schedule from the database."""

    Entry = ModelEntry

    _schedule = None  # cached mapping of task name -> ModelEntry
    _last_update = None  # time of the previous ``schedule_changed`` check
    _initial_read = True  # the first access to ``schedule`` always loads from the database
    _heap_invalidated = False  # makes the next ``schedules_equal`` call report a change

    lock: Lock | None = None
    lock_key = f'{settings.CELERY_REDIS_PREFIX}:beat_lock'

    def __init__(self, *args, **kwargs):
        """
        :param kwargs: must contain ``app``; ``max_interval`` is optional
        """
        self.app = kwargs['app']
        # Names of entries whose run state still has to be written by ``sync``;
        # created before ``super().__init__``, which presumably may already load the schedule
        self._dirty = set()
        super().__init__(*args, **kwargs)
        # Flush pending run state when the scheduler is finalized or the process exits
        self._finalize = Finalize(self, self.sync, exitpriority=5)
        self.max_interval = kwargs.get('max_interval') or self.app.conf.beat_max_loop_interval or DEFAULT_MAX_INTERVAL

    def setup_schedule(self):
        """Override: load the schedule, then install default and configured beat entries."""
        logger.info('setup_schedule')
        tasks = self.schedule
        self.install_default_entries(tasks)
        self.update_from_dict(self.app.conf.beat_schedule)

    async def get_all_task_schedulers(self):
        """
        Load all enabled task schedulers from the database.

        :return: mapping of task name -> entry
        """
        async with async_db_session() as db:
            logger.debug('DatabaseScheduler: Fetching database schedule')
            stmt = select(TaskScheduler).where(TaskScheduler.enabled == 1)
            query = await db.execute(stmt)
            tasks = query.scalars().all()
            return {task.name: self.Entry(task, app=self.app) for task in tasks}

    def schedule_changed(self) -> bool:
        """
        Check whether the schedule was changed externally.

        Compares the Redis ``last_update`` timestamp with the time of the previous
        check; if the key is missing it is seeded with the current time.

        :return: True if the timestamp is newer than the previous check
        """
        now = timezone.now()
        last_update = run_await(redis_client.get)(f'{settings.CELERY_REDIS_PREFIX}:last_update')
        if not last_update:
            run_await(redis_client.set)(f'{settings.CELERY_REDIS_PREFIX}:last_update', timezone.to_str(now))
            return False

        last, ts = self._last_update, timezone.from_str(last_update)
        try:
            return bool(ts and ts > (last if last else ts))
        finally:
            self._last_update = now

    def reserve(self, entry):
        """Override: advance the entry and mark it dirty so ``sync`` persists it."""
        new_entry = next(entry)
        # Store the entry by name, because the entry object itself may be replaced
        self._dirty.add(new_entry.name)
        return new_entry

    def close(self):
        """Override: release the distributed beat lock (if held) before closing."""
        if self.lock:
            logger.info('beat: Releasing lock')
            if run_await(self.lock.owned)():
                run_await(self.lock.release)()
            self.lock = None

        super().close()

    def sync(self):
        """
        Override: write the run state of dirty entries back to the database.

        Entries that cannot be saved stay dirty and are retried on the next call.
        """
        failed = set()
        pending = None  # name popped from ``_dirty`` but not yet handled
        try:
            while self._dirty:
                pending = self._dirty.pop()
                try:
                    # Use the cached mapping directly: going through ``self.schedule``
                    # would query Redis per entry and could reload the schedule
                    # mid-sync, replacing entries (and their unsaved run state) with
                    # stale database copies. ``_schedule`` is None before the first load.
                    entry = self._schedule[pending]
                except (KeyError, TypeError) as e:
                    logger.error(f'保存任务 {pending} 最新状态失败：{e} ')
                    failed.add(pending)
                else:
                    run_await(entry.save)()
                    logger.debug(f'保存任务 {pending} 最新状态到数据库')
                pending = None
        except DatabaseError as e:
            logger.exception('同步时出现数据库错误: %r', e)
        except InterfaceError as e:
            logger.warning(f'DatabaseScheduler InterfaceError：{str(e)}，等待下次调用时重试...')
        finally:
            # Retry later (only the ones that failed or were interrupted)
            if pending is not None:
                failed.add(pending)
            self._dirty |= failed

    def update_from_dict(self, beat_dict: dict):
        """Override: persist beat entries to the database and add the enabled ones to the schedule."""
        s = {}
        for name, entry_fields in beat_dict.items():
            try:
                entry = run_await(self.Entry.from_entry)(name, app=self.app, **entry_fields)
                if entry.model.enabled:
                    s[name] = entry
            except Exception:
                logger.error(f'添加任务 {name} 到数据库失败')
                raise

        tasks = self.schedule
        tasks.update(s)

    def install_default_entries(self, data):
        """
        Override: install the ``celery.backend_cleanup`` entry when ``result_expires`` is set.

        :param data: current schedule (unused here)
        """
        entries = {}
        if self.app.conf.result_expires:
            entries.setdefault(
                'celery.backend_cleanup',
                {
                    'task': 'celery.backend_cleanup',
                    'schedule': schedules.crontab('0', '4', '*'),
                    'options': {'expire_seconds': 12 * 3600},
                },
            )
        self.update_from_dict(entries)

    def schedules_equal(self, *args, **kwargs):
        """Override: report a change once after the heap was invalidated by a reload."""
        if self._heap_invalidated:
            self._heap_invalidated = False
            return False
        return super().schedules_equal(*args, **kwargs)

    @property
    def schedule(self) -> dict[str, ModelEntry]:
        """
        Return the current schedule.

        It is (re)loaded from the database on first access and whenever
        ``schedule_changed`` reports an external change; pending run state is
        synced first.
        """
        initial = update = False
        if self._initial_read:
            logger.debug('DatabaseScheduler: initial read')
            initial = update = True
            self._initial_read = False
        elif self.schedule_changed():
            logger.info('DatabaseScheduler: Schedule changed.')
            update = True

        if update:
            logger.debug('beat: Synchronizing schedule...')
            # Persist pending run state before the in-memory entries are replaced
            self.sync()
            self._schedule = run_await(self.get_all_task_schedulers)()
            # The schedule changed: invalidate the heap used by ``Scheduler.tick``
            if not initial:
                self._heap = []
                self._heap_invalidated = True
            logger.debug(
                'Current schedule:\n%s',
                '\n'.join(repr(entry) for entry in self._schedule.values()),
            )

        return self._schedule


async def extend_scheduler_lock(lock):
    """
    Keep the scheduler lock alive; loops forever.

    Every ``DEFAULT_LOCK_INTERVAL`` seconds it calls
    ``lock.extend(DEFAULT_MAX_LOCK_TIMEOUT)``. With the extend script installed by
    ``acquire_distributed_beat_lock`` this resets the TTL rather than adding to it.
    Failures are logged and retried on the next cycle.

    :param lock: the scheduler lock; a falsy value makes every cycle a no-op
    :return:
    """
    while True:
        await asyncio.sleep(DEFAULT_LOCK_INTERVAL)
        if not lock:
            continue
        try:
            await lock.extend(DEFAULT_MAX_LOCK_TIMEOUT)
        except Exception as exc:
            logger.error(f'Failed to extend lock: {exc}')


@beat_init.connect
def acquire_distributed_beat_lock(sender=None, *args, **kwargs):
    """
    ``beat_init`` handler: acquire the distributed beat lock on startup.

    Blocks until the Redis lock is acquired, stores it on the scheduler, and
    schedules ``extend_scheduler_lock`` to keep it alive.

    :param sender: the object that sent the signal (its ``scheduler`` is used)
    :return:
    """
    beat_scheduler = sender.scheduler
    lock_name = beat_scheduler.lock_key
    if not lock_name:
        return

    logger.debug('beat: Acquiring lock...')
    beat_lock = redis_client.lock(
        lock_name,
        timeout=DEFAULT_MAX_LOCK_TIMEOUT,
        sleep=beat_scheduler.max_interval,
    )
    # Replace redis-py's extend script: ours resets the TTL to the requested value
    # instead of adding it to the remaining TTL.
    beat_lock.lua_extend = redis_client.register_script(LUA_EXTEND_TO_SCRIPT)
    acquire = run_await(beat_lock.acquire)
    acquire()
    logger.info('beat: Acquired lock')
    beat_scheduler.lock = beat_lock

    # NOTE(review): this runs in synchronous code; with no running loop,
    # ``asyncio.get_event_loop()`` relies on implicit loop creation (deprecated in
    # recent Python versions), and the refresh task only advances while that loop
    # is being run -- confirm it is the loop ``run_await`` drives.
    asyncio.get_event_loop().create_task(extend_scheduler_lock(beat_lock))
